{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics import classification_report,confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(178, 13)\n",
      "(178,)\n"
     ]
    }
   ],
   "source": [
    "#载入数据\n",
    "data = np.genfromtxt('wine_data.csv',delimiter=',')\n",
    "x_data = data[:,1:]\n",
    "y_data = data[:,0]\n",
    "print(x_data.shape)\n",
    "print(y_data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [],
   "source": [
    "#数据切分,30%为测试集，70%为训练集\n",
    "x_train,x_test,y_train,y_test=train_test_split(x_data,y_data,test_size=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [],
   "source": [
    "#数据标准化\n",
    "scaler = StandardScaler()\n",
    "x_train = scaler.fit_transform(x_train)\n",
    "x_test = scaler.fit_transform(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLPClassifier(hidden_layer_sizes=(100, 50, 20), max_iter=3000)"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#构建模型，3个隐藏层，第一隐藏层100个神经元，第二隐藏层50个神经元，第三隐藏层20个神经元，训练3000周期\n",
    "mlp = MLPClassifier(hidden_layer_sizes=(100,50,20),max_iter=3000)\n",
    "#训练模型\n",
    "mlp.fit(x_train,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         1.0       1.00      1.00      1.00        20\n",
      "         2.0       1.00      1.00      1.00        21\n",
      "         3.0       1.00      1.00      1.00        13\n",
      "\n",
      "    accuracy                           1.00        54\n",
      "   macro avg       1.00      1.00      1.00        54\n",
      "weighted avg       1.00      1.00      1.00        54\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions = mlp.predict(x_test)\n",
    "print(classification_report(y_test,predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[20  0  0]\n",
      " [ 0 21  0]\n",
      " [ 0  0 13]]\n"
     ]
    }
   ],
   "source": [
    "print(confusion_matrix(y_test,predictions))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
